{
  "cells": [
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "wm0ezXfhdp2P"
      },
      "source": [
        "# Convert TensorFlow model checkpoints to a SavedModel"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "KAy1ciItdrDB"
      },
      "source": [
        "Given the checkpoints exported during TensorFlow model training, our goal is to convert them into a SavedModel for inference.\u003cbr\u003e\n",
        "A checkpoint is a set of binary files that stores the values of the model's variables, such as weights, biases and optimizer state. Checkpoint file names typically contain `ckpt`, for example `ckpt-1000.index` and `ckpt-1000.data-00000-of-00001`. Checkpoints do not contain any description of the computation defined by the model and thus are typically only useful when the source code that will use the saved parameter values is available.\u003cbr\u003e\n",
        "A SavedModel contains a complete TensorFlow program, including trained parameters and computation. It does not require the original model-building code to run, which makes it useful for sharing or deploying with TFLite, TensorFlow.js, TensorFlow Serving or TensorFlow Hub."
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "Lg4uH43Qds0q"
      },
      "source": [
        "**Note** - This notebook assumes that it runs in Google Colab and that the model files are stored on Google Drive. To run it as a regular Jupyter notebook on a local workstation or a remote server, skip the Drive mount and adjust the paths to match your storage location."
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "BofYg406d4LV"
      },
      "source": [
        "## Import libraries \u0026 clone the TensorFlow Models repository"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "oegDgL7yaaAq"
      },
      "outputs": [],
      "source": [
        "# Install the official Model Garden package, then RESTART THE RUNTIME in Colab\n",
        "!pip install tf-models-official"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "fMROz-xXdx6c"
      },
      "outputs": [],
      "source": [
        "import os\n",
        "from google.colab import drive\n",
        "import yaml\n",
        "import tensorflow as tf"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "executionInfo": {
          "elapsed": 1197,
          "status": "ok",
          "timestamp": 1659384634603,
          "user": {
            "displayName": "Umair Sabir",
            "userId": "06940594206388957365"
          },
          "user_tz": 420
        },
        "id": "gz1ajpHgeAJT",
        "outputId": "1187e44e-82eb-4be1-8adc-1b50f6d7d0ed"
      },
      "outputs": [
        {
          "name": "stdout",
          "output_type": "stream",
          "text": [
            "Drive already mounted at /content/gdrive; to attempt to forcibly remount, call drive.mount(\"/content/gdrive\", force_remount=True).\n",
            "ln: failed to create symbolic link '/mydrive/My Drive': File exists\n",
            "Successful\n"
          ]
        }
      ],
      "source": [
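        "# Use this if your model and data are stored in Google Drive\n",
        "drive.mount('/content/gdrive')\n",
        "\n",
        "# Link the Drive folder to a shorter path. A failing `!ln` shell command does\n",
        "# not raise a Python exception, so create the link with os.symlink instead.\n",
        "if os.path.lexists('/mydrive'):\n",
        "  print('/mydrive already exists')\n",
        "else:\n",
        "  os.symlink('/content/gdrive/My Drive', '/mydrive')\n",
        "  print('Successful')"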
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "executionInfo": {
          "elapsed": 290,
          "status": "ok",
          "timestamp": 1659384648663,
          "user": {
            "displayName": "Umair Sabir",
            "userId": "06940594206388957365"
          },
          "user_tz": 420
        },
        "id": "rRGalo90e2my",
        "outputId": "f0305c8b-cf06-4637-a83c-05f5471313e7"
      },
      "outputs": [
        {
          "name": "stdout",
          "output_type": "stream",
          "text": [
            "fatal: destination path 'models' already exists and is not an empty directory.\n"
          ]
        }
      ],
      "source": [
        "# Clone the TensorFlow Models repository\n",
        "!git clone --depth 1 https://github.com/tensorflow/models"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "executionInfo": {
          "elapsed": 189,
          "status": "ok",
          "timestamp": 1659384681521,
          "user": {
            "displayName": "Umair Sabir",
            "userId": "06940594206388957365"
          },
          "user_tz": 420
        },
        "id": "HalXsX7BqdyX",
        "outputId": "cb5555e4-0b77-4036-9230-1c01fcf1afaf"
      },
      "outputs": [
        {
          "name": "stdout",
          "output_type": "stream",
          "text": [
            "/content/models\n"
          ]
        }
      ],
      "source": [
        "%cd models"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "8o7QzpHOsFHS"
      },
      "source": [
        "## **MUST CHANGE** - Define the parameters according to your needs"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "xak5WwMXppDF"
      },
      "outputs": [],
      "source": [
        "# Name of the Model Garden experiment used for training. It depends on the\n",
        "# model and backbone; in our case, Mask R-CNN with a ResNet-FPN backbone.\n",
        "EXPERIMENT_TYPE = 'maskrcnn_resnetfpn_coco' #@param {type:\"string\"}\n",
        "\n",
        "# Path to the folder that holds the checkpoints and other files exported\n",
        "# during model training\n",
        "CHECKPOINT_PATH = '/mydrive/plastics_model/version_1/' #@param {type:\"string\"}\n",
        "\n",
        "# Path where the SavedModel will be exported to\n",
        "EXPORT_DIR_PATH = '/mydrive/plastics_model/experiment/' #@param {type:\"string\"}\n",
        "\n",
        "# The config file (params.yaml) is stored alongside the checkpoints\n",
        "CONFIG_FILE = os.path.join(CHECKPOINT_PATH, 'params.yaml')"
      ]
    },
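    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "Before exporting, it can help to confirm that `CHECKPOINT_PATH` actually contains checkpoints. The cell below is a minimal sketch, assuming the training run wrote TF2-style index files named `ckpt-N.index`, where `N` is the training step. The helper `list_checkpoint_steps` is a hypothetical name introduced here and is not part of the Model Garden API; adjust the pattern if your files are named differently."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {},
      "outputs": [],
      "source": [
        "import os\n",
        "import re\n",
        "\n",
        "# Assumption: checkpoints are named 'ckpt-N.index', where N is the training step\n",
        "_CKPT_INDEX = re.compile(r'^ckpt-(\\d+)\\.index$')\n",
        "\n",
        "def list_checkpoint_steps(checkpoint_dir):\n",
        "  \"\"\"Return the sorted training steps that have a checkpoint index file.\"\"\"\n",
        "  steps = []\n",
        "  for name in os.listdir(checkpoint_dir):\n",
        "    match = _CKPT_INDEX.match(name)\n",
        "    if match:\n",
        "      steps.append(int(match.group(1)))\n",
        "  return sorted(steps)\n",
        "\n",
        "# Example usage with the parameter defined above:\n",
        "# print('Found checkpoints at steps:', list_checkpoint_steps(CHECKPOINT_PATH))"
      ]
    },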
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "ddZzH5KhqSAy"
      },
      "outputs": [],
      "source": [
        "# Read params.yaml to get the input image height and width used in training\n",
        "with open(CONFIG_FILE) as f:\n",
        "    config = yaml.safe_load(f)\n",
        "\n",
        "# input_size is stored as [height, width, ...]\n",
        "HEIGHT = config['task']['model']['input_size'][0]\n",
        "WIDTH = config['task']['model']['input_size'][1]\n",
        "print(HEIGHT, WIDTH)"
      ]
    },
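    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "The lookup above raises a bare `KeyError` if `params.yaml` has a different layout. As a hedged sketch, the hypothetical helper `get_input_size` below reads the same `task.model.input_size` entry from an already-loaded config dictionary and fails with a clearer message. It assumes that `input_size` lists the height first and the width second, as the cell above does."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {},
      "outputs": [],
      "source": [
        "def get_input_size(config):\n",
        "  \"\"\"Return (height, width) from a loaded params.yaml dictionary.\"\"\"\n",
        "  try:\n",
        "    input_size = config['task']['model']['input_size']\n",
        "  except (KeyError, TypeError) as e:\n",
        "    raise ValueError('params.yaml has no task.model.input_size entry') from e\n",
        "  # Expect at least two entries: height and width\n",
        "  if not isinstance(input_size, (list, tuple)) or len(input_size) != len(input_size[:2]) + len(input_size[2:]) or len(input_size[:2]) != 2:\n",
        "    raise ValueError(f'Unexpected input_size: {input_size!r}')\n",
        "  return int(input_size[0]), int(input_size[1])\n",
        "\n",
        "# Example with a hand-written dictionary (values are illustrative only)\n",
        "example_config = {'task': {'model': {'input_size': [512, 1024, 3]}}}\n",
        "print(get_input_size(example_config))"
      ]
    },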
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "ZWrbqHqJt947"
      },
      "source": [
        "## Run the export script to convert checkpoints to a SavedModel"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "k3wuYiVte4t_"
      },
      "outputs": [],
      "source": [
        "# run the conversion command\n",
        "!python -m official.vision.serving.export_saved_model --experiment=$EXPERIMENT_TYPE \\\n",
        "                   --export_dir=$EXPORT_DIR_PATH \\\n",
        "                   --checkpoint_path=$CHECKPOINT_PATH \\\n",
        "                   --batch_size=1 \\\n",
        "                   --input_image_size=$HEIGHT,$WIDTH \\\n",
        "                   --input_type=tflite \\\n",
        "                   --config_file=$CONFIG_FILE"
      ]
    },
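    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "The shell command above interpolates the paths unquoted, so it breaks if a path contains spaces (for example `My Drive`). One possible alternative, sketched below, builds the same argument list in Python so that it can be run with `subprocess` without shell quoting. The flags are the ones used above; the helper name `build_export_command` is hypothetical."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {},
      "outputs": [],
      "source": [
        "import subprocess\n",
        "import sys\n",
        "\n",
        "def build_export_command(experiment, export_dir, checkpoint_path,\n",
        "                         config_file, height, width, batch_size=1):\n",
        "  \"\"\"Build the argument list for official.vision.serving.export_saved_model.\"\"\"\n",
        "  return [\n",
        "      sys.executable, '-m', 'official.vision.serving.export_saved_model',\n",
        "      f'--experiment={experiment}',\n",
        "      f'--export_dir={export_dir}',\n",
        "      f'--checkpoint_path={checkpoint_path}',\n",
        "      f'--batch_size={batch_size}',\n",
        "      f'--input_image_size={height},{width}',\n",
        "      '--input_type=tflite',\n",
        "      f'--config_file={config_file}',\n",
        "  ]\n",
        "\n",
        "# Example usage with the parameters defined earlier in this notebook:\n",
        "# cmd = build_export_command(EXPERIMENT_TYPE, EXPORT_DIR_PATH, CHECKPOINT_PATH,\n",
        "#                            CONFIG_FILE, HEIGHT, WIDTH)\n",
        "# subprocess.run(cmd, check=True)  # a list of arguments needs no shell quoting"
      ]
    },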
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "jQb13BbW78bs"
      },
      "source": [
        "# Convert the SavedModel to a TFLite model"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "QEKaddZ58Ab6"
      },
      "source": [
        "Given the SavedModel exported from the trained TensorFlow model, our goal is to convert it to TFLite for inference on edge devices."
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "ZXpSX-w_8A1c"
      },
      "source": [
        "TensorFlow Lite is a set of tools that enables on-device machine learning by helping developers run their models on mobile, embedded and edge devices."
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "20sV-bx59GRD"
      },
      "source": [
        "## **MUST CHANGE** - Define the parameters according to your needs"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "icoDtIin9REv"
      },
      "outputs": [],
      "source": [
        "# Full path, including the file name, where the TFLite model will be written\n",
        "TFLITE_PATH = '/mydrive/gtech/MRFs/Recykal/Latest_sharing_by_sanket/Google_Recykal/Taxonomy_version_2/model_version_1/plastics_model/tflite_fan/model.tflite' #@param {type:\"string\"}\n",
        "\n",
        "# Directory that contains the exported SavedModel\n",
        "SAVED_MODEL_DIR = os.path.join(EXPORT_DIR_PATH, 'saved_model')"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "YE1Z13xs9NtJ"
      },
      "source": [
        "## Convert the SavedModel to TFLite"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "1LY9yoUP6Gr4"
      },
      "outputs": [],
      "source": [
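        "# Load the SavedModel and convert it to a TFLite flatbuffer\n",
        "converter = tf.lite.TFLiteConverter.from_saved_model(saved_model_dir=SAVED_MODEL_DIR)\n",
        "tflite_model = converter.convert()\n",
        "\n",
        "# Write the converted model to disk\n",
        "with open(TFLITE_PATH, 'wb') as f:\n",
        "  f.write(tflite_model)"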
      ]
    }
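    ,
    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "After the conversion, a quick check that the output file looks like a TFLite flatbuffer can catch an empty or truncated write. The sketch below relies on the `TFL3` file identifier that TFLite flatbuffers carry at byte offset 4. The helper name `looks_like_tflite` is hypothetical, and passing this check does not guarantee that the model runs correctly."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {},
      "outputs": [],
      "source": [
        "import os\n",
        "\n",
        "def looks_like_tflite(path):\n",
        "  \"\"\"Return True if the file exists and carries the 'TFL3' identifier.\"\"\"\n",
        "  if not os.path.isfile(path):\n",
        "    return False\n",
        "  with open(path, 'rb') as f:\n",
        "    header = f.read(8)\n",
        "  # Bytes 0-3 hold the root table offset; bytes 4-7 hold the file identifier\n",
        "  return len(header) == 8 and header[4:8] == b'TFL3'\n",
        "\n",
        "# Example usage with the path defined above:\n",
        "# print(looks_like_tflite(TFLITE_PATH), os.path.getsize(TFLITE_PATH), 'bytes')"
      ]
    }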
  ],
  "metadata": {
    "colab": {
      "collapsed_sections": [],
      "name": "checkpoints_to_saved_model_to_tflite.ipynb",
      "provenance": []
    },
    "gpuClass": "standard",
    "kernelspec": {
      "display_name": "Python 3",
      "name": "python3"
    },
    "language_info": {
      "name": "python"
    }
  },
  "nbformat": 4,
  "nbformat_minor": 0
}
